#!/usr/bin/python

# make the clustering class-specific

import sys,os,signal,traceback
import fcntl
import matplotlib,tables
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
from pylab import *
from multiprocessing import Pool
import ocrolib
from ocrolib.ligatures import lig
from ocrolib import morph,linerec,improc,h5utils
from ocrolib.toplevel import *

import warnings,numpy
warnings.simplefilter('ignore',numpy.RankWarning)
warnings.filterwarnings('error','.*invalid value.*')

signal.signal(signal.SIGINT,lambda *args:sys.exit(1))

# these options control alignment
import argparse
parser = argparse.ArgumentParser(description = """
Computes recognition lattices for text lines.  Also displays the bestpath
result (recognition result without language model).

Inputs: textline.png
Outputs: textline.lattice, textline.rseg.png
""")
parser.add_argument("-X","--exec",dest="execute",help="execute before anything else (usually used for imports)",default="None")
#parser.add_argument("-s","--segmenter",default="lineseg.DPSegmentLine()",help="segmenter (%(default)s)")
parser.add_argument("-s","--segmenter",default="lineseg.ComboSegmentLine()",help="segmenter (%(default)s)")
parser.add_argument("-m","--model",default=ocrolib.default.model,help="character model (%(default)s)")
parser.add_argument("-w","--whitespace",default=ocrolib.default.space,help="space model (%(default)s)")
parser.add_argument("-Q","--parallel",type=int,default=1,help="number of parallel processes to use (%(default)s)")
parser.add_argument('-q','--quiet',action="store_true",help="don't output progress info")
parser.add_argument('-S','--sizemode',default=None,help="how to resize characters for extraction: perchar, linerel, lineabs (%(default)s)")
parser.add_argument('-x','--exitonerr',action="store_true",help='time to wait after displaying results')
parser.add_argument('-e','--lineest',default=ocrolib.default.lineest,help="line geometry model (%(default)s)")
parser.add_argument("--extract",default=None,help="extract characters for cmodel training")
parser.add_argument("--noglob",action="store_true",help="don't perform expansion on the arguments")
parser.add_argument('--hdfappend',action="store_true",help="append to any exiting HDF5 file (this is multi-processing safe)")
parser.add_argument('--baselinedegree',type=int,default=1,help="polynomial degree used for modeling baseline")
parser.add_argument('--writebestpath',action="store_true",help="write the best path as a recognition result, without a language model (this is only for debugging)")
parser.add_argument('--borderclean',default=8,type=int,help="remove components that are contained within a margin that's this large (%(default)s)")

parser.add_argument("--show",help="show progress",action="store_true")
parser.add_argument('--delay',type=float,default=9999,help='time to wait after displaying results')
parser.add_argument("files",default=[],nargs='*',help="input lines")
args = parser.parse_args()
if not args.noglob: 
    args.files = ocrolib.glob_all(args.files)
if len(args.files)==0:
    parser.print_help()
    sys.exit(0)

# make these parameters eventually
charsize = 32
target_xheight = 16

exec args.execute

###
### This first option doesn't actually recognize anything, it extracts
### characters for cmodel training.  The reason it is in `ocropus-linerec`
### is because this is the only program that needs to know about normalization.
###

class Hdf5Writer:
    """Context manager for writing character patches to an HDF5 file.

    On entry an exclusive `fcntl.lockf` lock is taken on `fname+".lock"`,
    the HDF5 file is opened with `mode`, and the `patches`/`classes` arrays
    are created if the file does not have them yet.  On exit the file is
    closed and the lock released.

    fname -- path of the HDF5 file
    mode  -- mode string passed to `tables.openFile`
    size  -- shape of a single patch; `insert` rejects any other shape
    """
    def __init__(self,fname,mode="w",size=(charsize,charsize)):
        self.mode = mode
        self.size = size
        self.fname = fname
    def __enter__(self):
        self.fd = os.open(self.fname+".lock",os.O_RDWR|os.O_APPEND|os.O_CREAT)
        try:
            fcntl.lockf(self.fd,fcntl.LOCK_EX)
            self.h5 = tables.openFile(self.fname,self.mode)
            self.create()
        except:
            # __exit__ is not invoked when __enter__ raises, so the lock
            # (and a possibly opened file) must be released here
            self._release()
            raise
        return self
    def __exit__(self,*args):
        # _release always drops the lock, even if closing the file fails
        self._release()
        os.system("sync")
        # presumably gives the file system time to settle before the next
        # process takes the lock -- TODO confirm the delay is still needed
        os.system("sleep 3")
    def _release(self):
        """Close the HDF5 file (if open), then unlock and close the lock file."""
        h5 = self.__dict__.pop("h5",None)
        try:
            if h5 is not None: h5.close()
        finally:
            try:
                fcntl.lockf(self.fd,fcntl.LOCK_UN)
            finally:
                os.close(self.fd)
    def create(self):
        """Create the extendable `patches` and `classes` arrays unless the
        file already has a `patches` node."""
        from tables import Float32Atom,Int64Atom,Filters
        h5 = self.h5
        if "patches" not in dir(h5.root):
            h5.createEArray(h5.root,'patches',Float32Atom(),shape=(0,)+self.size,filters=Filters(9))
            h5.createEArray(h5.root,'classes',Int64Atom(),shape=(0,),filters=Filters(9))
            #h5.createVLArray(h5.root,'files',StringAtom(120),filters=Filters(9))
            #h5.createEArray(h5.root,'bboxes',Float32Atom(shape=(4,)),shape=(0,),filters=Filters(9))
    def insert(self,image,cls,cost=0.0,count=0,fname="",lgeo=None,bbox=(-1,-1,-1,-1)):
        """Append one patch and its class (stored as `lig.ord(cls)`).

        `image.shape` must equal `self.size`.  The remaining arguments
        (cost, count, fname, lgeo, bbox) are accepted but not stored.
        """
        assert image.shape==self.size,"wrong image shape: %s"%(image.shape,)
        h5 = self.h5
        h5.root.patches.append([image])
        h5.root.classes.append([lig.ord(cls)])
        #h5.root.files.append(fname)
        #h5.root.bboxes.append([array(bbox,'f')])

class CharResizer:
    """Normalizes the size of character images cut from one text line.

    sizemode       -- "linerel" (scale relative to the estimated line
                      geometry), "perchar" (normalize each character on its
                      own) or "lineabs" (not implemented)
    target_xheight -- x-height that "linerel" scales to; must be in 8..200

    Line geometry is estimated by the module-level `emodel`; call `load`
    or `set` before `show`, `baselineAsText`, or a "linerel" `resize`.
    """
    def __init__(self,sizemode,target_xheight):
        assert isinstance(sizemode,str)
        self.sizemode = sizemode
        assert target_xheight>=8 and target_xheight<=200
        self.target_xheight = target_xheight
    def load(self,fname):
        """Read the `.bin.png` image belonging to `fname` and `set` it."""
        base = ocrolib.allsplitext(fname)[0]
        lname = base+".bin.png"
        limage = ocrolib.read_image_gray(lname)
        return self.set(limage)
    def set(self,limage):
        """Estimate line geometry for `limage` and return self.

        Passes `1-limage` to `emodel.lineParameters` using the polynomial
        order from --baselinedegree, and stores the four results as
        avgbaseline, xheight, blp and xlp (blp/xlp are the baseline and
        x-line polynomial coefficients evaluated with polyval elsewhere).
        """
        self.limage = limage
        params = emodel.lineParameters(1-limage,order=args.baselinedegree)
        self.avgbaseline,self.xheight,self.blp,self.xlp = params
        return self
    def show(self):
        """Display the line image with the baseline and x-line drawn over it."""
        imshow(self.limage)
        xs = arange(self.limage.shape[1])
        plot(xs,polyval(self.blp,xs))
        plot(xs,polyval(self.xlp,xs))
    def baselineAsText(self):
        """Return the baseline polynomial as space-separated "%.4f" values.

        The coefficients are negated and the image height is added to the
        constant term (presumably converting to a bottom-up y axis --
        TODO confirm against the consumer of the .baseline file).
        """
        h = len(self.limage)
        params = -array(self.blp)
        params[-1] += h
        return " ".join("%.4f"%x for x in params)
    def resize(self,image,bbox):
        """Return `image` normalized according to `self.sizemode`.

        bbox -- pair of slices; only `bbox[1].start` is used, as the x
                position at which the line geometry is evaluated ("linerel")

        Raises BadImage if the x-line is not above the baseline, BadInput
        if per-character normalization fails, Unimplemented for "lineabs",
        and Exception for any other sizemode.
        """
        if self.sizemode=="linerel":
            x = bbox[1].start
            baseline = polyval(self.blp,x)
            xline = polyval(self.xlp,x)
            if xline>=baseline: 
                raise BadImage("xline>=baseline %d %d"%(xline,baseline))
            scale = self.target_xheight*1.0/self.xheight
            image = improc.line_normalize(1.0*image,scale=scale,bar=(xline,baseline))
        elif self.sizemode=="perchar":
            try:
                image = improc.classifier_normalize(1.0*image)
            except Exception:
                # Only real errors are converted; SystemExit (raised by the
                # script's SIGINT handler) and KeyboardInterrupt must not
                # be reported as a bad input.
                traceback.print_exc()
                raise BadInput("classifier_normalize failed, skipping")
        elif self.sizemode=="lineabs":
            raise Unimplemented("lineabs")
        else:
            raise Exception("unknown sizemode:"+self.sizemode)
        return image

if args.extract is not None:
    if args.quiet: print "extracting..."
    emodel = ocrolib.ocropus_find_file(args.lineest)
    print "loading",emodel
    emodel = ocrolib.load_component(emodel)
    print "got",emodel
    sizemode = args.sizemode or "linerel"
    print "sizemode",sizemode

    hdfmode = "a" if args.hdfappend else "w"
    with Hdf5Writer(args.extract,mode=hdfmode) as h5:
        h5utils.log(h5.h5," ".join(sys.argv))
        h5.h5.setNodeAttr("/","sizemode",sizemode)
        
    insertions = []
    for fname in args.files:
        try:
            base = ocrolib.allsplitext(fname)[0]
            gname = base+".aligned"
            if not os.path.exists(gname): 
                gname = base+".gt.txt"
            if not os.path.exists(gname): 
                print fname,"=EXTRACTED=","    *** NO ALIGNED TEXT ***",gname 
                continue
            cname = base+".cseg.png"
            if not os.path.exists(cname): 
                print fname,"=EXTRACTED=","    *** NO CSEG ***",cname
                continue
            rname = base+".rseg.png"
            if not os.path.exists(rname): 
                print fname,"=EXTRACTED=","    *** NO RSEG ***",rname
                continue
            gt = ocrolib.gt_explode(ocrolib.read_text(gname))
            if len(gt)==0: 
                print fname,"=EXTRACTED=","    *** EMPTY GT ***"
                continue
            if gt[-1]=="\n": gt = gt[:-1]
            cseg = ocrolib.read_line_segmentation(cname)
            rseg = ocrolib.read_line_segmentation(rname)
            csegs = linerec.extract_csegs(cseg)
            maxseg = amax([c.last for c in csegs])
            if maxseg!=len(gt):
                print fname,"=EXTRACTED=","    *** maxseg AND aligned lengths DIFFER***",len(gt),maxseg
                continue
            csegs = [c.replace(out=[(gt[c.first-1],0.0)]) for i,c in enumerate(csegs)]
            csegs = [c for c in csegs if c.out[0][0]!="~"]
            rsegs = linerec.extract_rsegs(rseg)
            misseg = linerec.extract_non_csegs(rsegs,csegs)
            misseg = [c.replace(out=[("~",0.0)]) for c in misseg]
            if args.show:
                print csegs
                ion(); gray(); clf()
                figure(1); ocrolib.showgrid([c.img for c in csegs[:100]],xlabels=[c.out[0][0] for c in csegs]) 
                figure(2); ocrolib.showgrid([m.img for m in misseg[:100]],xlabels=[m.out[0][0] for m in misseg]) 
                ginput(1,args.delay)
            # TODO optionally double-check against model here
            if not args.quiet: 
                print fname,"=EXTRACTED=",ocrolib.gt_implode(gt)
            resizer = CharResizer(sizemode,target_xheight).load(fname)
            if resizer.xheight<8 or resizer.xheight>100: # TODO make these arguments
                print "bad xheight:",xheight
                continue
            if args.show:
                clf()
                subplot(311); morph.showlabels(cseg)
                subplot(312); morph.showlabels(cseg)
                subplot(313); resizer.show()
                ginput(1,args.delay)
            for c in csegs+misseg:
                image = c.img
                image = resizer.resize(image,c.bbox)
                cls = c.out[0][0]
                insertions.append((image,cls))
        except ocrolib.RecognitionError as e:
            print str(e).replace("\n"," ")[:80]
            continue
        if len(insertions)>10000:
            with Hdf5Writer(args.extract,mode="a") as h5:
                for image,cls in insertions:
                    h5.insert(image=image,cls=cls)
            insertions = []
    with Hdf5Writer(args.extract,mode="a") as h5:
        for image,cls in insertions:
            h5.insert(image=image,cls=cls)
    sys.exit(0)

###
### This is the actual line recognizer
###


cmodel = ocrolib.ocropus_find_file(args.model)
print "loading",cmodel
cmodel = ocrolib.load_component(cmodel)
if not args.quiet: print "got",cmodel

sizemode = getattr(cmodel,"sizemode","perchar")
if not args.quiet: print "sizemode",sizemode

wmodel = ocrolib.ocropus_find_file(args.whitespace)
if not args.quiet: print "loading",wmodel
wmodel = ocrolib.load_component(wmodel)
if not args.quiet: print "got",wmodel

emodel = ocrolib.ocropus_find_file(args.lineest)
if not args.quiet: print "loading",emodel
emodel = ocrolib.load_component(emodel)
if not args.quiet: print "got",emodel

if not args.quiet: print "segmenter",args.segmenter
segmenter = eval(args.segmenter)
if not args.quiet: print "got",segmenter

def invert_image(image):
    """Return `image` with every value v replaced by max(image)-v."""
    peak = amax(image)
    return peak-image

def loutputs(image,floor=1e-6,keep_rejects=0):
    """Classify a character image with the module-level `cmodel` and
    return a list of (class, cost) pairs, where cost is -log(p) with p
    clipped below at `floor`.  Classes containing "~" are dropped unless
    `keep_rejects` is true; if nothing remains, a single ("~",30) entry
    is returned."""
    costs = []
    for cls,p in cmodel.coutputs(image):
        if "~" in cls and not keep_rejects: continue
        costs.append((cls,-log(max(p,floor))))
    if not costs:
        costs = [("~",30)]
    return costs

def make_connected(rsegs,insert=None):
    """Given a list of segmentation hypotheses, inserts
    reject classes for any segments that are not present.
    This reconnects the graph in case one of the previous
    steps has rejected some segment entirely.

    rsegs  -- list of segments with `first` and `last` attributes; it is
              extended in place
    insert -- output list for the inserted segments; defaults to
              [("~",30.0)]

    Returns `rsegs` (also when it is empty, so that callers can always
    use the return value).  Every index from the smallest `first` to the
    largest `last` ends up with a single-index (i,i) segment."""
    if insert is None: insert = [("~",30.0)]
    if len(rsegs)<1: return rsegs
    # a set keeps the membership test below O(1)
    transitions = set((r.first,r.last) for r in rsegs)
    lo = amin([r.first for r in rsegs])
    # the covered range ends at the largest `last`, not the largest `first`
    hi = amax([r.last for r in rsegs])
    for i in range(lo,hi+1):
        if (i,i) not in transitions:
            # each inserted segment gets its own copy of the output list
            rsegs.append(linerec.Segment(first=i,last=i,img=zeros((1,1)),
                                         bbox=(slice(0,1),slice(0,1)),
                                         out=list(insert),sp=array([1.0,0])))
    return rsegs

# with --show, switch pylab to interactive mode and the gray colormap
if args.show:
    ion()
    gray()

def process1(fname):
    try:
        if not args.quiet: print fname,"=RAW=",

        # read the image and display it
        
        image = ocrolib.read_image_gray(fname)

        if args.borderclean>0:
            h,w = image.shape
            mask = zeros((h,w),'i')
            d = minimum(args.borderclean,h//2-4)
            mask[d:-d,d:-d] = 1
            image = 1.0*invert_image(morph.keep_marked(image<ocrolib.midrange(image),mask))

        if args.show:
            figure(2); clf()
            figure(1); clf(); subplot(311)
            imshow(image)
            ginput(1,0.1)

        try:
            checktype(image,LINE)
        except:
            print "    *** DOESN'T SATISFY GEOMETRIC CONSTRAINS ON LINES, SKIPPED ***",image.shape
            return

        # generate character candidates

        rseg = segmenter.charseg(image)
        rseg = morph.renumber_by_xcenter(rseg)
        rsegs = linerec.extract_rsegs(rseg)
        if len(rsegs)<1:
            if args.quiet: print fname,"=RAW=",
            print "    *** NO RAW SEGMENTS ***"
            return

        # apply the resizer

        resizer = CharResizer(sizemode,target_xheight).set(image)
        rsegs = [r.replace(img=resizer.resize(r.img,r.bbox)) for r in rsegs]

        if args.show:
            subplot(311)
            imshow(image)
            xs = arange(image.shape[1])
            baseline = polyval(resizer.blp,xs)
            xline = polyval(resizer.xlp,xs)
            plot(baseline); plot(xline)

        if args.show:
            figure(1); subplot(312)
            morph.showlabels(rseg)
            ginput(1,0.1)

        # classify each character

        recognized = [r.replace(out=loutputs(r.img)) for r in rsegs]

        # compute whitespace probabilities
        
        wmodel.setLine(invert_image(image))
        recognized = [r.replace(sp=wmodel.classifySpace(r.bbox[1].stop)) for r in recognized]

        if args.show:
            figure(2)
            labels = [(r.out[0][0] if r.out and r.out[0][1]<1 else "_") for r in recognized]
            ocrolib.showgrid([r.img for r in recognized][:100],cols=20,xlabels=labels)
            ginput(1,0.1)
            
        # make sure the resulting graph is connected

        recognized = make_connected(recognized)

        # output the best path without a language model for debugging

        labels,costs,trans = linerec.bestpath(recognized,noreject=0)

        if labels is None:
            if args.quiet: print fname,"=RAW=",
            print "    *** FAILED (no bestpath) ***"
            # still continue
        elif not args.quiet: 
            print "".join(labels)

        base,_ = ocrolib.allsplitext(fname)

        # optionally, write best path without a language model
        # (this is intended mostly for debugging)

        if args.writebestpath and labels is not None:
            # we write both so that debugging tools still have the
            # raw text output after language modeling
            ocrolib.write_text(base+".raw.txt","".join(labels))
            ocrolib.write_text(base+".txt","".join(labels))

        # write the lattice and the raw segmentation
        
        with open(base+".lattice","w") as stream:
            linerec.write_lattice(stream,recognized)

        # write the raw segmentation

        ocrolib.write_line_segmentation(base+".rseg.png",rseg)

        # write line geometry information

        ocrolib.write_text(base+".xheight","%.1f"%resizer.xheight)
        ocrolib.write_text(base+".baseline",resizer.baselineAsText())
        
        if args.show:
            ginput(1,args.delay)
    except ocrolib.RecognitionError,e:
        print "    ***",fname,":",e,"***"
        return
    except:
        print "    *** ERROR IN",fname,"***"
        traceback.print_exc()
        if args.exitonerr: sys.exit(1)
        return

###
### top level: either run sequentially or in parallel under multiprocessing
###

print "recognizing",len(args.files),"files"
if args.show:
    args.parallel = 1
if args.parallel==1:
    for fname in args.files:
        process1(fname)
else:
    pool = Pool(processes=args.parallel)
    result = []
    for r in pool.imap_unordered(process1,args.files):
        result.append(r)
        if len(result)%100==0: print "==========",len(result),"of",len(args.files),"=========="
